from matplotlib import pyplot as plt
from scipy.misc import imresize
from os import listdir
import math
import numpy as np
import cv2
import random
from time import time

# Sentinel returned by slope() for vertical segments (a large finite number, not float('inf')).
INF = 9999999999999
# Values whose magnitude is <= this threshold are snapped to zero by f_cut().
f_zero = 1e-10


class Timer:
    """
    Small wall-clock helper with two modes.

    * Progress-notifier mode (``as_progress_notifier=True``): call ``start(n)`` once,
      then ``tick_timer()`` after each of the ``n`` work items. The timer keeps a
      running mean of the per-item time and estimates the remaining time.
    * Simple-timer mode (``as_progress_notifier=False``): call ``start()`` and then
      ``check()`` to measure the time elapsed since the last start.

    If ``print_fn`` is set to a zero-argument callable, it is invoked after every
    ``tick_timer()`` / ``check()`` so callers can report the updated attributes.
    """

    def __init__(self, as_progress_notifier=True):
        """
        :param as_progress_notifier: If True, track progress over a known number of
            items with tick_timer(). If False, act as a simple stopwatch with check().
        """
        self.whole_number_of_data = 0   # total number of items expected (progress mode)
        self.current_progress = 0       # number of items ticked since start()
        self.time_moving_average = 0    # mean seconds per item so far (progress mode)
        self.elapsed_time = 0           # seconds elapsed (cumulative in progress mode)
        self.remain_time = 0            # estimated seconds left (progress mode)
        self.tick_start_time = 0        # time() at the last start() / tick
        self.is_progress_notifier = as_progress_notifier
        self.timer_ready = False
        self.print_fn = None            # optional zero-argument callback run after each update

    def start(self, number_of_data=None):
        """
        (Re)start the timer.

        :param number_of_data: Total number of items. Required in progress mode,
            ignored otherwise.
        :raises ValueError: In progress mode when number_of_data is None.
        """
        if self.is_progress_notifier:
            if number_of_data is None:
                raise ValueError('You should feed number_of_data.')
            self.whole_number_of_data = number_of_data
            self.current_progress = 0
        self.timer_ready = True
        self.tick_start_time = time()

    def tick_timer(self):
        """
        Record that one more item finished and update the time statistics.

        After the call, ``elapsed_time`` is the cumulative time of all ticks since
        start(), ``time_moving_average`` is the mean time per item, and
        ``remain_time`` is that mean multiplied by the number of items left.

        :raises AttributeError: If start() was not called, or the timer is not in
            progress-notifier mode.
        """
        if not self.timer_ready:
            raise AttributeError('Need to initialize timer by start().')
        if not self.is_progress_notifier:
            raise AttributeError('You should set as_progress_notifier to True if you want to use tick_timer().')

        tick = time() - self.tick_start_time
        # Cumulative time so far, rebuilt from the running mean, plus this tick.
        # When current_progress is 0 this is just the first tick.
        self.elapsed_time = self.current_progress * self.time_moving_average + tick
        self.current_progress += 1
        self.time_moving_average = self.elapsed_time / self.current_progress
        self.remain_time = (self.whole_number_of_data - self.current_progress) * self.time_moving_average

        if self.print_fn is not None:
            self.print_fn()

        # Start the next tick after the callback, so printing time is not counted.
        self.tick_start_time = time()

    def check(self, start_timer=True):
        """
        Store in ``elapsed_time`` the seconds elapsed since the last start().

        :param start_timer: If True, restart the timer after measuring.
        :raises AttributeError: If the timer is in progress-notifier mode.
        """
        if self.is_progress_notifier:
            raise AttributeError('You should set as_progress_notifier to False if you want to use check().')
        self.elapsed_time = time() - self.tick_start_time

        if self.print_fn is not None:
            self.print_fn()

        if start_timer:
            self.start(self.whole_number_of_data)


def random_integers(min, max, size, exclude: list=None):
    """
    Draw ``size`` random integers from the inclusive range [min, max].

    (``min``/``max`` shadow builtins but are kept for keyword-argument compatibility.)

    :param min: Lower bound (inclusive)
    :param max: Upper bound (inclusive)
    :param size: Number of integers to draw
    :param exclude: Optional values that must never be drawn
    :return: Without ``exclude``: a list of ``size`` distinct integers.
        With ``exclude``: a NumPy array of ``size`` integers. They are distinct
        unless fewer than ``size`` allowed values exist, in which case they are
        drawn with replacement.
    :raises ValueError: If max < min; if ``size`` distinct values are requested
        from a smaller range (no ``exclude``); or if every candidate is excluded.
    """
    if max < min:
        raise ValueError('Max must be larger than min')
    if exclude is None:
        return list(random.sample(range(min, max + 1), size))

    excluded = set(exclude)  # O(1) membership tests below
    n_candidates = size + len(exclude)
    if (max - min) >= n_candidates:
        # Over-sample by len(exclude) distinct values. At most len(exclude) of them
        # can be dropped, so at least ``size`` allowed candidates remain.
        rand_list = random.sample(range(min, max + 1), n_candidates)
    else:
        rand_list = range(min, max + 1)
    allowed = [r for r in rand_list if r not in excluded]
    if not allowed and size > 0:
        raise ValueError('Every candidate value in [min, max] is excluded.')
    return np.random.choice(allowed, size, replace=len(allowed) < size)


def random_select(data_list: list, size, exclude: list=None, return_unselected=False):
    """
    Randomly select "size" elements from a list.

    :param data_list: Original list to select from
    :param size: Number of elements to select
    :param exclude: Indices of elements that must not be selected
    :param return_unselected: If True, also return the elements that were not selected
    :return: ``selected``, or ``(selected, unselected)`` when return_unselected is True.
        ``unselected`` keeps the original order and includes any excluded elements.
    """
    selection = random_integers(0, len(data_list) - 1, size, exclude)
    selected = [data_list[i] for i in selection]
    if not return_unselected:
        return selected

    # A set gives O(1) membership; scanning `selection` per element was O(n * size).
    chosen = set(selection)
    unselected = [d for idx, d in enumerate(data_list) if idx not in chosen]
    return selected, unselected


def is_image(file_name):
    """
    Return True if ``file_name`` ends in a common raster-image extension.

    The check is case-insensitive (``.Jpg`` counts). It requires a real extension:
    a bare name such as ``"png"`` with no dot is not an image.

    :param file_name: File name or path
    :return: bool
    """
    _, dot, extension = file_name.rpartition('.')
    return bool(dot) and extension.lower() in ('jpg', 'jpeg', 'tif', 'tiff', 'png', 'bmp')


def resize_images(path, size, start=1500, end=2000):
    """
    Resize image files in a directory in place.

    Only entries whose position in ``listdir(path)`` lies in ``[start, end)`` are
    processed. The defaults (1500, 2000) reproduce the window this function used
    to hard-code. Pass ``start=0, end=None`` to process every file. ``listdir``
    returns entries in arbitrary order, so an index window is not a stable subset
    across systems.

    :param path: Directory containing the images
    :param size: Target size, passed straight to ``imresize``
    :param start: First listing index to process (inclusive)
    :param end: Listing index to stop at (exclusive); None means no upper bound
    """
    files = listdir(path)
    n_files = len(files)
    stop = n_files if end is None else end
    for idx, f in enumerate(files):
        if not (start <= idx < stop) or not is_image(f):
            continue
        print('[%d/%d]' % (idx + 1, n_files))
        resize_image(path + '/' + f, size)


def resize_image(file, size):
    """
    Resize a single image file in place. Files that are not images are ignored.

    :param file: Path to the image; it is overwritten with the resized result
    :param size: Target size, passed straight to imresize
    """
    if not is_image(file):
        return
    resized = imresize(plt.imread(file), size)
    plt.imsave(file, resized)


def extract_roi(image, img_size: tuple, roi_size: tuple,
                crop_ratio: tuple = (0.7765, 0.8365, 0.235, 0.765)):
    """
    Resize an image, crop a fractional Region Of Interest, and resize the crop.

    :param image: Input image
    :param img_size: Size for the first resize (passed to ``imresize``). Its first two
        entries are used as the row and column extents when computing the crop window.
    :param roi_size: Size of the returned ROI (passed to ``imresize``)
    :param crop_ratio: (top, bottom, left, right) crop bounds, given as fractions of
        img_size[0] (rows) and img_size[1] (columns). The default is the window
        this function has always used.
    :return: ROI image. For multi-channel input, only the first three channels are kept.
    """
    image = imresize(image, img_size)
    top, bottom, left, right = crop_ratio
    rows = slice(int(img_size[0] * top), int(img_size[0] * bottom))
    cols = slice(int(img_size[1] * left), int(img_size[1] * right))
    if image.ndim == 3:
        # Drop channels beyond the third (e.g. an alpha channel).
        image = image[rows, cols, :3]
    else:
        # Grayscale input has no channel axis to slice.
        image = image[rows, cols]
    image = imresize(image, roi_size)
    return image


def remove_remains(img, interest_point):
    """
    Keep only the region connected to ``interest_point`` and zero everything else.

    A uint8 copy of ``img`` is flood-filled with 0 starting at ``interest_point``,
    using cv2.floodFill with its default settings. Subtracting that filled copy
    from the unfilled one leaves exactly the pixels the flood fill erased.

    :param img: Input image. It is cast to a uint8 copy, so the caller's array is untouched.
    :param interest_point: Seed point in (x, y) order, as cv2.floodFill expects
    :return: uint8 image holding only the region connected to interest_point
    """
    canvas = img.astype(np.uint8)
    height, width = canvas.shape[:2]
    # cv2.floodFill requires a mask two pixels larger than the image in each dimension.
    flood_mask = np.zeros((height + 2, width + 2), np.uint8)
    erased = canvas.copy()
    cv2.floodFill(erased, flood_mask, tuple(interest_point), 0)
    return canvas - erased


def fill_hole_and_remove_remains(img, interest_point):
    """
    Merge the component at ``interest_point`` back into the image, binarise, and
    keep only the component connected to ``interest_point``.

    Steps:
    1. Isolate the component at ``interest_point`` with remove_remains().
    2. Add that component to ``img``.
    3. Set every positive pixel to 1.
    4. Isolate the component at ``interest_point`` again.

    NOTE(review): remove_remains() flood-fills with 0. When the pixel at
    ``interest_point`` is already 0 (e.g. inside a hole), the first call appears
    to return an all-zero image, so no hole is actually filled. Confirm the
    intended behaviour.

    :param img: Input image (presumably a binary mask; TODO confirm)
    :param interest_point: Seed point, forwarded unchanged to remove_remains()
    :return: Image containing only the component connected to interest_point
    """
    component = remove_remains(img=img, interest_point=interest_point)
    merged = img + component
    merged[merged > 0] = 1
    return remove_remains(img=merged, interest_point=interest_point)


def slope(point1, point2):
    """
    Slope dy/dx of the line through two points given as (x, y).

    Both points are passed through f_cut() first. If their x coordinates are equal
    (a vertical line), the INF sentinel is returned instead of dividing by zero.
    """
    a = f_cut(point1)
    b = f_cut(point2)
    return INF if a[0] == b[0] else (a[1] - b[1]) / (a[0] - b[0])


def img_center(img, bbox=None, show=False):
    """
    Find the centroid of the non-zero pixels of a 1-channel image, plus their bounding box.

    :param img: Input image (1 channel image, indexed as img[y][x])
    :param bbox: Optional search window [[left, top], [right, bottom]]. Only pixels
        with left <= x < right and top <= y < bottom are considered. Negative
        starts are clamped to 0, and windows extending past the image are clipped.
    :param show: If True, draw the centroid on a uint8 copy of the image and display it
    :return: (center, bbox). ``center`` is an np.int32 array [x, y] holding the mean
        position of the non-zero pixels, rounded with np.round (half to even).
        If ``bbox`` was given, it is returned unchanged. Otherwise the tight,
        inclusive box [[left, top], [right, bottom]] of the non-zero pixels is returned.
    :raises ValueError: If no non-zero pixel is found, since the centroid is then undefined.
    """
    if bbox is None:
        x_offset = y_offset = 0
        region = img
    else:
        # Clamp the start so a negative coordinate cannot wrap around the image.
        x_offset = max(int(bbox[0][0]), 0)
        y_offset = max(int(bbox[0][1]), 0)
        region = img[y_offset:bbox[1][1], x_offset:bbox[1][0]]

    # Vectorised lookup of the non-zero pixels (replaces a per-pixel Python loop).
    ys, xs = np.nonzero(region > 0)
    if xs.size == 0:
        raise ValueError('img_center() found no non-zero pixels, so the center is undefined.')
    xs = xs + x_offset
    ys = ys + y_offset

    center = np.round(np.array([xs.sum(), ys.sum()]) / xs.size).astype(np.int32)

    if show:
        img_plane = img.astype(np.uint8)
        cv2.circle(
            img=img_plane,
            center=(int(center[0]), int(center[1])),  # plain ints for OpenCV's Point parser
            radius=3,
            color=2,
            thickness=1
        )
        plt.imshow(img_plane)
        plt.show()

    if bbox is not None:
        return center, bbox
    return center, [[int(xs.min()), int(ys.min())], [int(xs.max()), int(ys.max())]]


def round_int(num):
    """
    Round a number, array, or container of numbers to np.int32.

    Rounding uses np.round, so exact halves go to the nearest even integer.

    :param num: Scalar, np.ndarray, list, tuple, or dict whose values are any of these
    :return: For list/tuple input, a list/tuple built from the rounded np.int32 array.
        For dict input, a dict with each value rounded recursively.
        Otherwise, np.round(num) cast to np.int32 (a scalar or an array).
    """
    if type(num) is dict:
        # np.round() cannot process a dict, so round each value instead.
        return {key: round_int(value) for key, value in num.items()}
    if type(num) in (list, tuple):
        return type(num)(np.round(num).astype(np.int32))
    return np.round(num).astype(np.int32)


def np_cast(ary, t):
    """Convert ``ary`` to dtype ``t`` through its ``astype`` method and return the result."""
    converted = ary.astype(t)
    return converted


def f_cut(x):
    """
    Snap values whose magnitude is at most ``f_zero`` to zero.

    :param x: A float (Python or NumPy float scalar), list, tuple, or np.ndarray.
        Other types (e.g. ints) are returned unchanged.
    :return: Same kind of value as the input, with near-zero entries replaced by
        zero. Lists and tuples come back as lists/tuples of NumPy scalars.
    """
    if isinstance(x, (float, np.floating)):
        return 0 if math.fabs(x) <= f_zero else x
    if isinstance(x, np.ndarray):
        return x * np.logical_not(np.fabs(x) <= f_zero)
    if type(x) in (list, tuple):
        # Tuples were previously passed through uncut (e.g. points given to slope()).
        arr = np.array(x)
        return type(x)(arr * np.logical_not(np.fabs(arr) <= f_zero))
    return x


def vector_length(x: np.ndarray):
    """
    Euclidean (L2) length of a 1-D vector.

    :param x: 1-D sequence of real numbers (np.ndarray, list or tuple)
    :return: sqrt of the sum of squared components, as a float; 0.0 for an empty vector
    """
    # Builtin sum() replaces a hand-rolled accumulator that shadowed it. math.pow
    # keeps the per-component float conversion, so results are bit-identical.
    return math.sqrt(sum(math.pow(component, 2) for component in x))


def voc_color_map():
    """
    Build the 256-entry VOC-style colour palette.

    The bits of each label index are dealt round-robin to R, G and B (bit 0 to R,
    bit 1 to G, bit 2 to B, bit 3 to R, ...). Each channel is filled from its most
    significant bit downwards.

    :return: list of 256 ``[r, g, b]`` lists of ints in 0..255
    """
    palette = []
    for label in range(256):
        red = green = blue = 0
        bits = label
        for shift in range(7, -1, -1):
            red |= (bits & 1) << shift
            green |= ((bits >> 1) & 1) << shift
            blue |= ((bits >> 2) & 1) << shift
            bits >>= 3
        palette.append([red, green, blue])
    return palette


def voc_label_to_color(label, n_classes=21):
    """
    Colourise a label map with the palette from voc_color_map().

    :param label: Integer label map (presumably 2-D (H, W); TODO confirm for other ranks)
    :param n_classes: Number of class ids (0 .. n_classes-1) to colourise. Defaults
        to the 21 classes previously hard-coded, and is capped at the palette size (256).
    :return: uint8 array of shape (H, W, 3) for a 2-D label map. Pixels whose label
        is outside [0, n_classes) keep their raw label value (cast to uint8) in all
        three channels.
    """
    cmap = voc_color_map()
    image = np.tile(label, [3, 1, 1])
    # Take the masks from an untouched copy of the ids. Otherwise a colour value
    # written for one class could be matched again as a later class id once
    # n_classes is large enough to include such values.
    class_ids = image.copy()
    for c in range(min(n_classes, len(cmap))):
        colour = cmap[c]
        for i in range(3):
            image[i][class_ids[i] == c] = colour[i]

    image = np.transpose(image, [1, 2, 0]).astype(np.uint8)
    return image
